/* OOQP                                                               *
 * Authors: E. Michael Gertz, Stephen J. Wright                       *
 * (C) 2001 University of Chicago. See Copyright Notification in OOQP */

/* PARDISO Solver Implemented by J.Currie September 2011 */

#include "PardisoSolver.h"
#include "SparseStorage.h"
#include "SparseSymMatrix.h"
#include "SimpleVector.h"
#include "SimpleVectorHandle.h"

/**
 * Construct a PARDISO (Intel MKL) linear solver for the symmetric matrix
 * ssm.
 *
 * Only solver parameters are set here: the matrix is not read, no work
 * memory is allocated and PARDISO is not called until the first
 * matrixChanged() (which runs firstCall()). The owning pointer members are
 * null-initialised so that destroying a never-factored solver does not
 * free garbage.
 *
 * @param ssm  matrix to be factored; a reference is kept in mMat.
 */
PardisoSolver::PardisoSolver( SparseSymMatrix * ssm )
{
	//Assign Input Matrix
	SpReferTo( mMat, ssm );

	//Setup
	mtype		= -2;	// real symmetric indefinite
	nrhs		= 1;	// Only ever 1 rhs
	maxfct		= 1;	// maximum numerical factorizations to keep
	mnum		= 1;	// which factorization from above to use
	msglvl		= 0;	// do not print PARDISO info
	error		= 0;	// initialize error
	is_init		= 0;	// not initialized
	n		= 0;	// dimension unknown until firstCall()

	//No work memory until firstCall(); null pointers are safe to delete
	x		= 0;
	inc		= 0;
	lmem		= 0;
	a		= 0;
	ja		= 0;
	ia		= 0;
	diagInd		= 0;
	diagVec		= 0;

	//Initialize PARDISO's internal memory handle (must be zero before phase 11)
	for (int i = 0; i < 64; i++) {
		pt[i] = 0;
	}

	//Setup Pardiso Options: zero everything, then set the ones we use
	for (int i = 0; i < 64; i++) {
		iparm[i] = 0;
	}
	iparm[0] = 1;	// Do not use solver defaults: the values below are user-supplied
	iparm[1] = 2;	// Fill-in reordering from METIS
	iparm[3] = 0;	// Not using CGS iteration
	iparm[4] = 0;	// No user fill-in reducing permutation
	iparm[5] = 1;	// Write solution back into b (x is only workspace)
	iparm[7] = 2;	// Max numbers of iterative refinement steps
	iparm[9] = 8;	// Perturb the pivot elements with 1E-8 (sym indef default)
	iparm[10] = 0;	// Don't use nonsymmetric permutation and scaling
	iparm[11] = 0;	// Solve normal Ax = b
	iparm[12] = 0;	// Maximum weighted matching switched off (default for symmetric)
	iparm[17] = 0;	// Don't report number of nonzeros in the factor LU
	iparm[18] = 0;	// Don't report Mflops for LU factorization
	iparm[20] = 1;	// Use Bunch & Kaufman (1x1 and 2x2) pivoting
	iparm[34] = 1;	// Zero-based (C-style) ia/ja indexing; 1 is the documented value
}

/**
 * One-time setup, performed on the first matrixChanged().
 *
 * Steps:
 * - Reads the dimension and nonzero count of mMat.
 * - Allocates the work buffers.
 * - Fills PARDISO's CSR arrays (updateSparseVec).
 * - Runs PARDISO phase 11 (reordering and symbolic factorization), which
 *   also allocates PARDISO's internal memory.
 *
 * Sets is_init so that later calls take the diagonal-only update path.
 * A PARDISO error is only reported with printf; is_init is still set.
 */
void PardisoSolver::firstCall()
{
	int mm, nnz;

	//Setup now we have a real matrix
	mMat->getSize(mm,n);	// Get Size
	nnz = mMat->numberOfNonZeros();

	//Setup Vectors Memory
	x = new double[n];		// PARDISO x argument (solution is written to b, see iparm[5])
	inc = new int[n];		// per-output-row fill counters; indexed by column, so size n
	lmem = new int[nnz];	// row index of every stored nonzero (expanded krowM)
	a = new double[nnz];	// values of the transposed (upper-triangular) matrix
	ja = new MKL_INT[nnz];	// column indices of the transposed matrix
	ia = new MKL_INT[n+1];	// row pointers of the transposed matrix
	diagInd = new int[n];		// positions of the diagonal entries within a
	diagVec = new SimpleVector(n);	// buffer for reading mMat's diagonal

	//Update Sparse Vectors
	updateSparseVec();

	// Reordering and Symbolic Factorization. This step also allocates
	// all memory that is necessary for the factorization.
	phase = 11;
	PARDISO (pt, &maxfct, &mnum, &mtype, &phase,
		&n, a, ia, ja, &idum, &nrhs,
		iparm, &msglvl, &ddum, &ddum, &error);
	if (error != 0) {
		printf("\nERROR during symbolic factorization: %d", error);
	}

	is_init = 1; //now initialized
}

void PardisoSolver::diagonalChanged( int /* idiag */, int /* extent */ )
{
  this->matrixChanged();
}

void PardisoSolver::matrixChanged()
{
	// Initialize if first call!
	if( !is_init ) 
		this->firstCall();
	//Otherwise only diagonals get updated (no transpose neccesary)
	else
	{
		mMat->getDiagonal(*diagVec);
		for(int i = 0; i < n; i++)
			a[diagInd[i]] = (*diagVec)[i];
	}

	//Numerical Factorization
	phase = 22;
	PARDISO (pt, &maxfct, &mnum, &mtype, &phase,
		&n, a, ia, ja, &idum, &nrhs,
		iparm, &msglvl, &ddum, &ddum, &error);
	if (error != 0) {
		printf("\nERROR during numerical factorization: %d", error);
	}
}


/**
 * Solve A x = rhs_in in place using the current factorization.
 *
 * rhs_in must actually be a SimpleVector; the reference dynamic_cast
 * throws std::bad_cast otherwise. Runs PARDISO phase 33 (back substitution
 * plus iterative refinement). Because iparm[5] is set to 1, PARDISO writes
 * the solution over rhs_in's storage, and the member x serves only as
 * workspace. Errors are only reported with printf.
 */
void PardisoSolver::solve( OoqpVector& rhs_in )
{
	SimpleVector & rhsVec = dynamic_cast<SimpleVector &>(rhs_in);
	double * rhsData = rhsVec.elements();

	phase = 33;
	PARDISO (pt, &maxfct, &mnum, &mtype, &phase,
		&n, a, ia, ja, &idum, &nrhs,
		iparm, &msglvl, rhsData, x, &error);
	if (error != 0)
		printf("\nERROR during solution: %d", error);
}

/**
 * Build PARDISO's CSR arrays (ia, ja, a) from mMat by transposition.
 *
 * mMat is read in row-compressed form (krowM row pointers, jcolM column
 * indices, M values). Each stored entry (r, c) is placed at (c, r) of the
 * output, turning the stored lower triangle into the upper triangle that
 * PARDISO's symmetric modes read. The input is visited in row order, so
 * column indices come out ascending within each output row.
 *
 * Also records in diagInd the position within a of each diagonal entry,
 * in the order diagonals are encountered. matrixChanged() treats
 * diagInd[i] as the slot of diagonal i. That assumes every diagonal entry
 * is stored explicitly — NOTE(review): confirm for all callers.
 */
void PardisoSolver::updateSparseVec()
{
	int dInd = 0, ind, pos, i, j, no;
	//Sizes
	int nnz = mMat->numberOfNonZeros();
	//Get Sparse Indices
	int *col = mMat->jcolM();
	int *row = mMat->krowM();
	double *vals = mMat->M();

	//Reset Work Memory. Explicit loops instead of memset(..., sizeof(int)):
	//ia is MKL_INT, which is 64-bit in ILP64 builds, so a byte count based
	//on sizeof(int) would leave half of it uncleared.
	for(i = 0; i < n; i++)
		inc[i] = 0;
	for(i = 0; i <= n; i++)
		ia[i] = 0;

	//Transpose Sparse Matrix (lower tri -> upper tri)
	//Count entries per column of mMat (= per output row), shifted by one
	for(i = 0; i < nnz; i++)
		ia[col[i]+1]++;

	//Cumulative sum turns the counts into row pointers (ia[0] stays 0)
	for(i = 2; i <= n; i++)
		ia[i] += ia[i-1];

	//Expand krowM so that lmem[k] is the row of the k-th stored nonzero
	ind = 0;
	for(i = 0; i < n; i++) {
		no = row[i+1]-row[i];
		for(j = 0; j < no; j++)
			lmem[ind++] = i;
	}

	//Scatter entry (lmem[i], col[i]) to (col[i], lmem[i]) in the output
	for(i = 0; i < nnz; i++)
	{
		ind = col[i];
		pos = ia[ind] + inc[ind]++;
		ja[pos] = lmem[i];	//new column index = original row
		a[pos] = vals[i];	//transpose mMat
		//Save transposed diagonal indices (in the order encountered)
		if(lmem[i] == col[i])
			diagInd[dInd++] = pos;
	}
}

/**
 * Release PARDISO's internal factor memory (phase -1) and the work
 * buffers.
 *
 * Both the buffers and PARDISO's memory are only created in firstCall(),
 * which sets is_init. A solver that was never factored therefore has
 * nothing to release. Its pointer members may not have been assigned, so
 * they must not be deleted.
 */
PardisoSolver::~PardisoSolver()
{
	//firstCall() never ran: no PARDISO state and no buffers to free
	if( !is_init )
		return;

	//Free PARDISO Memory
	phase = -1;
	PARDISO (pt, &maxfct, &mnum, &mtype, &phase,
		&n, &ddum, ia, ja, &idum, &nrhs,
		iparm, &msglvl, &ddum, &ddum, &error);

	delete [] x;
	delete [] inc;
	delete [] lmem;
	delete [] a;
	delete [] ja;
	delete [] ia;
	delete [] diagInd;
	delete diagVec;
}